"""
Advantage actor-critic (A2C) update for recurrent (RL^2-style) policies.

Based on https://github.com/ikostrikov/pytorch-a2c-ppo-acktr
"""
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from utils import helpers as utl

# Compute device: CUDA when available, otherwise CPU. Not referenced by the
# class in this file; presumably imported by other modules — verify.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class A2C_RL2:
    """Advantage actor-critic (A2C) updater for a recurrent actor-critic model.

    ``get_losses`` re-runs the model over a stored rollout and builds the
    weighted loss:

        policy-gradient loss
        - entropy_loss_coeff * mean entropy
        + critic_loss_coeff * value loss
        (+ activity_l2_loss_coeff * hidden-state L2 penalty, if enabled)

    ``update_parameters`` applies one gradient step with global-norm gradient
    clipping. When annealing is enabled, it also advances a linear
    learning-rate schedule.
    """

    def __init__(
        self,
        args,
        actor_critic,
        # loss function
        critic_loss_coeff,
        entropy_loss_coeff,
        activity_l2_loss_coeff,
        # optimization
        policy_optimizer,
        policy_eps,
        policy_lr,
        policy_anneal_lr,
        train_steps
    ):
        """Store the model and loss weights, then build the optimizer and LR schedule.

        Args:
            args: config namespace. This class reads ``shared_rnn``,
                ``policy_use_activity_l2_regularization`` and
                ``policy_max_grad_norm`` from it.
            actor_critic: ``nn.Module`` whose forward returns
                ``(logits, values, actor_h, critic_h)`` when
                ``args.shared_rnn`` is false and ``(logits, values, rnn_h)``
                when it is true.
            critic_loss_coeff: weight of the value (critic) loss.
            entropy_loss_coeff: weight of the entropy bonus (subtracted).
            activity_l2_loss_coeff: weight of the hidden-state L2 penalty.
            policy_optimizer: ``'adam'`` or ``'rmsprop'``.
            policy_eps: optimizer epsilon.
            policy_lr: initial learning rate.
            policy_anneal_lr: if true, decay the LR linearly to 0 over
                ``train_steps`` calls to ``update_parameters``.
            train_steps: number of updates over which the LR is annealed.

        Raises:
            ValueError: if ``policy_optimizer`` is not recognised.
        """
        self.args = args

        # the model
        self.actor_critic = actor_critic

        # coefficients for weighting value loss, entropy loss, and activity L2 loss
        self.critic_loss_coeff = critic_loss_coeff
        self.entropy_loss_coeff = entropy_loss_coeff
        self.activity_l2_loss_coeff = activity_l2_loss_coeff

        # optimizer
        if policy_optimizer == 'adam':
            self.optimizer = optim.Adam(
                actor_critic.parameters(),
                lr=policy_lr,
                eps=policy_eps
            )
        elif policy_optimizer == 'rmsprop':
            self.optimizer = optim.RMSprop(
                actor_critic.parameters(),
                lr=policy_lr,
                eps=policy_eps,
                alpha=0.99
            )
        else:
            # Fail fast. Otherwise self.optimizer would be missing and the
            # error would only surface later, far from its cause.
            raise ValueError(
                f"Unknown policy_optimizer: {policy_optimizer!r} "
                "(expected 'adam' or 'rmsprop')"
            )

        # learning rate annealing
        self.lr_scheduler_policy = None
        if policy_anneal_lr:
            def lr_lambda(num_updates):
                # Linear decay of the LR multiplier from 1 to 0. The clamp keeps
                # the LR from going negative (i.e. gradient ascent) if more than
                # train_steps updates are performed.
                return max(0.0, 1.0 - num_updates / train_steps)

            self.lr_scheduler_policy = optim.lr_scheduler.LambdaLR(
                self.optimizer, lr_lambda=lr_lambda)

    def _evaluate_rollout(self, policy_storage):
        """Re-run the actor-critic over the stored rollout and collect hidden states.

        Returns ``(action_logits, state_values, hidden_states)``, where
        ``hidden_states`` is ``[actor_h, critic_h]`` for separate RNNs or
        ``[rnn_h]`` for a shared RNN.
        """
        inputs = dict(
            curr_states=policy_storage.states_for_policy[:-1, :, :],
            prev_actions=policy_storage.actions[:-1, :, :],
            prev_rewards=policy_storage.rewards[:-1, :, :],
        )
        if self.args.shared_rnn:
            action_logits, state_values, rnn_hidden_states = self.actor_critic(
                **inputs, rnn_prev_hidden_states=None)
            return action_logits, state_values, [rnn_hidden_states]

        action_logits, state_values, actor_hidden_states, critic_hidden_states = \
            self.actor_critic(
                **inputs,
                actor_prev_hidden_states=None,
                critic_prev_hidden_states=None)
        return action_logits, state_values, [actor_hidden_states, critic_hidden_states]

    def get_losses(self, policy_storage):
        """Rebuild the computation graph over the stored rollout and compute the A2C losses.

        The actor-critic is re-run on ``states_for_policy[:-1]``,
        ``actions[:-1]`` and ``rewards[:-1]`` with ``None`` as the previous
        hidden state(s).

        Args:
            policy_storage: rollout buffer exposing ``states_for_policy``,
                ``actions``, ``rewards``, ``returns`` and ``state_values``.
                The rank-3 indexing below fixes a 3-D layout, presumably
                (time, batch, feature), with one more time step in the first
                three than in ``returns`` — TODO confirm against the storage
                class.

        Returns:
            tuple ``(loss, actor_loss, critic_loss, mean_entropy,
            activity_l2_loss)``. ``loss`` is the weighted total to pass to
            ``update_parameters``. ``activity_l2_loss`` is a zero tensor when
            activity-L2 regularization is disabled.
        """
        action_logits, state_values, hidden_states = self._evaluate_rollout(policy_storage)

        # get policy distribution
        action_pd = torch.distributions.Categorical(logits=action_logits)
        policy_entropy = action_pd.entropy()
        # actions[:-1] are fed to the model as "previous actions", so the action
        # chosen at each evaluated step is actions[1:]. The argmax over dim 2
        # suggests actions are stored one-hot (assumption — verify).
        actions = torch.argmax(policy_storage.actions[1:, :, :], dim=2)
        action_log_probs = action_pd.log_prob(actions)

        # Advantages use the values stored at rollout time. Reshape (not squeeze)
        # to the log-prob shape: squeeze() would also drop a batch/time dimension
        # of size 1, and (T,) * (T, 1) would broadcast to a (T, T) product.
        advantages = (policy_storage.returns - policy_storage.state_values)\
            .reshape(action_log_probs.shape)

        # update_parameters() zeroes gradients too; kept so behaviour is unchanged
        # for callers that backpropagate the returned loss themselves.
        self.optimizer.zero_grad()

        # policy-gradient loss; advantages are detached so only the actor term
        # receives this gradient
        actor_loss = -(advantages.detach() * action_log_probs).mean()
        # value regression against the stored returns
        critic_loss = (policy_storage.returns - state_values).pow(2).mean()
        if self.args.policy_use_activity_l2_regularization:
            # by default regularize both actor and critic (all hidden states)
            activity_l2_loss = sum(h.pow(2).mean() for h in hidden_states)
        else:
            # float placeholder on the same device as the other losses
            activity_l2_loss = torch.zeros(1, device=critic_loss.device)

        mean_entropy = policy_entropy.mean()
        # (loss = action loss - entropy bonus + value loss [+ regularization], weighted)
        loss = actor_loss - mean_entropy * self.entropy_loss_coeff \
            + critic_loss * self.critic_loss_coeff
        if self.args.policy_use_activity_l2_regularization:
            loss = loss + activity_l2_loss * self.activity_l2_loss_coeff

        return loss, actor_loss, critic_loss, mean_entropy, activity_l2_loss

    def update_parameters(self, loss):
        """Apply one optimization step for ``loss``.

        The step zeroes gradients, backpropagates ``loss``, clips the global
        gradient norm to ``args.policy_max_grad_norm``, and steps the
        optimizer. It then advances the LR schedule if annealing is enabled.
        """
        # zero out the gradients
        self.optimizer.zero_grad()

        # compute gradients
        # (will attach to all networks involved in this computation)
        loss.backward()
        nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.args.policy_max_grad_norm)

        # update
        self.optimizer.step()

        if self.lr_scheduler_policy is not None:
            self.lr_scheduler_policy.step()
